The purpose of this notebook is to infer the rate at which confirmed cases of COVID-19 are growing (or were growing) in various countries.
The notebook pulls data from the Johns Hopkins Data Repository of global Coronavirus COVID-19 cases, and then does the following things:
We then repeat these steps for US states.
The notebook is updated approximately daily.
For a great primer on exponential and logistic growth, watch this video.
The growth rate (and the doubling time) changes with time. As the exponential curve eventually turns into a logistic curve, the growth rate will shrink to zero (& the doubling time will consequently increase). So it's not a good idea to extrapolate trends far into the future based on current growth rates or doubling times.
The confirmed cases reported by each country are not the number of infections in that country; they count only the people who have tested positive.
The doubling time calculated here measures the growth of cumulative confirmed cases, which is different from the growth of infections. For example, if a country suddenly ramps up testing, then the number of confirmed cases will rapidly rise, but infections may not be rising at the same rate.
The doubling times inferred from the curve fits are not necessarily the current or most recent doubling times:
# Now
! date
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from termcolor import colored, cprint
import plotly.graph_objects as go
#import plotly.offline as offline
#offline.init_notebook_mode(connected=True)
# Base URL of the Johns Hopkins CSSE time-series CSV files.
datadir = 'https://github.com/CSSEGISandData/COVID-19/raw/master/csse_covid_19_data/csse_covid_19_time_series/'

# Global confirmed-case counts. The commented-out line below reads the
# deaths file from the same location instead.
df = pd.read_csv(f'{datadir}time_series_covid19_confirmed_global.csv')
#df = pd.read_csv( datadir + 'time_series_covid19_deaths_global.csv')

# Keep only the country column (position 1) and the last column of the file,
# then add up all rows that share a country.
cases = df.iloc[:, [1, -1]].groupby('Country/Region').sum()

# After grouping, the one remaining column is the file's last column; its
# label is reused below as the "most recent date" in messages and titles.
mostrecentdate = cases.columns[0]
print('\nTotal number of cases (in countries with at least 100 cases) as of', mostrecentdate)

# Largest totals first, restricted to countries with at least 100 cases.
cases = (
    cases
    .sort_values(by=mostrecentdate, ascending=False)
    .loc[lambda totals: totals[mostrecentdate] >= 100]
)
cases.head()
def logistic(t, a, b, c, d):
    """Evaluate a four-parameter logistic curve at ``t``.

    The curve is ``c + (d - c) / (1 + a * exp(-b * t))``: ``b`` is the
    rate inside the exponential, ``a`` scales the exponential term, and the
    value tends to ``d`` once the exponential term vanishes.
    """
    span = d - c
    denominator = 1 + a * np.exp(-b * t)
    return c + span / denominator
def exponential(t, a, b, c):
    """Evaluate ``a * exp(b * t) + c`` at ``t``.

    ``a`` scales the exponential, ``b`` is the rate inside it, and ``c`` is a
    constant offset added to the result.
    """
    growth = np.exp(b * t)
    return a * growth + c
def _r_squared(observed, fitted):
    # Coefficient of determination (R^2) of `fitted` against `observed`.
    residuals = observed - fitted
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((observed - np.mean(observed))**2)
    return 1 - (ss_res / ss_tot)


def plotCases(dataframe, column, country, maxfev=100000, use_plotly=False):
    """Plot cumulative cases for one region and estimate its doubling time.

    Rows of ``dataframe`` whose ``column`` equals ``country`` are summed
    column-wise over every column from position 4 onward, and leading
    zero-count entries are dropped. (Assumes the first four columns are
    non-date metadata, as in the JHU time-series files -- confirm if another
    layout is passed in.)

    Three estimates are then attempted and reported on stdout:

    * a "recent" doubling time from the ratio of the latest value to the
      value seven entries earlier (only when the series has grown);
    * a logistic curve fit, accepted when R^2 > 0.95;
    * an exponential curve fit, accepted when R^2 > 0.95.

    The data and every accepted fit are drawn with matplotlib, or with
    plotly when ``use_plotly`` is anything other than ``False``. The plot
    title reads the module-level ``mostrecentdate``.

    Parameters:
        dataframe: table holding one row per region/sub-region.
        column: name of the column that identifies the region.
        country: value of ``column`` to select.
        maxfev: maximum function evaluations handed to ``curve_fit``.
        use_plotly: ``False`` (default) for matplotlib, otherwise plotly.

    Returns:
        ``[doubling_time, half_width, recent_doubling_time]``. The first two
        come from the accepted fit with the higher R^2 (rounded to two
        decimals; ties go to the exponential fit) and are NaN when neither
        fit is accepted. ``half_width`` is 1.96 times the propagated standard
        error. ``recent_doubling_time`` is NaN when it could not be computed.
    """
    co = dataframe[dataframe[column] == country].iloc[:, 4:].T.sum(axis=1)
    co = pd.DataFrame(co)
    co.columns = ['Cases']
    co = co.loc[co['Cases'] > 0]

    y = np.array(co['Cases'])
    x = np.arange(y.size)

    # Decide the plotting backend once so that every branch below agrees.
    use_matplotlib = use_plotly is False

    recentdbltime = float('NaN')

    # y[-8] is the entry seven steps before y[-1], so at least eight points
    # are required (with only seven, y[-8] would raise IndexError).
    if len(y) >= 8:
        current = y[-1]
        lastweek = y[-8]

        if current > lastweek:
            print('\n** Based on Most Recent Week of Data **\n')
            print('\tConfirmed cases on',co.index[-1],'\t',current)
            print('\tConfirmed cases on',co.index[-8],'\t',lastweek)
            ratio = current/lastweek
            print('\tRatio:',round(ratio,2))
            print('\tWeekly increase:',round( 100 * (ratio - 1), 1),'%')
            dailypercentchange = round( 100 * (pow(ratio, 1/7) - 1), 1)
            print('\tDaily increase:', dailypercentchange, '% per day')
            recentdbltime = round( 7 * np.log(2) / np.log(ratio), 1)
            print('\tDoubling Time (represents recent growth):',recentdbltime,'days')

    if use_matplotlib:
        plt.figure(figsize=(10,5))
        plt.plot(x, y, 'ko', label="Original Data")
    else:
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=x, y=y, mode='markers', name='Original Data'))

    logisticworked = False
    exponentialworked = False

    # A failed fit must not stop the analysis, so each fit reports its
    # exception and execution continues with the next step.
    try:
        lpopt, lpcov = curve_fit(logistic, x, y, maxfev=maxfev)
        lerror = np.sqrt(np.diag(lpcov))

        # for logistic curve at half maximum, slope = growth rate/2. so doubling time = ln(2) / (growth rate/2)
        ldoubletime = np.log(2)/(lpopt[1]/2)
        # 95% half-width: 1.96 x the relative standard error of the rate
        ldoubletimeerror = 1.96 * ldoubletime * np.abs(lerror[1]/lpopt[1])

        logisticr2 = _r_squared(y, logistic(x, *lpopt))

        if logisticr2 > 0.95:
            if use_matplotlib:
                plt.plot(x, logistic(x, *lpopt), 'b--', label="Logistic Curve Fit")
            else:
                fig.add_trace(go.Scatter(x=x, y=logistic(x, *lpopt), mode='lines', line=dict(dash='dot'), name="Logistic Curve Fit") )
            print('\n** Based on Logistic Fit**\n')
            print('\tR^2:', logisticr2)
            print('\tDoubling Time (during middle of growth): ', round(ldoubletime,2), '(±', round(ldoubletimeerror,2),') days')
            print("\tparam: ", lpopt)
            logisticworked = True
        else:
            print("\n logistic R^2 ", logisticr2)
    except Exception as ex:
        cprint('\nException in logistic process ', 'red')
        cprint(type(ex), 'red')
        cprint(ex, 'red')

    try:
        # Bounds: 0 <= a <= 100, 0 <= b <= 0.9, -100 <= c <= 100.
        epopt, epcov = curve_fit(exponential, x, y, bounds=([0,0,-100],[100,0.9,100]), maxfev=maxfev)
        eerror = np.sqrt(np.diag(epcov))

        # for exponential curve, slope = growth rate. so doubling time = ln(2) / growth rate
        edoubletime = np.log(2)/epopt[1]
        # 95% half-width: 1.96 x the relative standard error of the rate
        edoubletimeerror = 1.96 * edoubletime * np.abs(eerror[1]/epopt[1])

        expr2 = _r_squared(y, exponential(x, *epopt))

        if expr2 > 0.95:
            if use_matplotlib:
                plt.plot(x, exponential(x, *epopt), 'r--', label="Exponential Curve Fit")
            else:
                fig.add_trace(go.Scatter(x=x, y=exponential(x, *epopt), mode='lines', line=dict(dash='dot'), name="Exponential Curve Fit"))
            print('\n** Based on Exponential Fit **\n')
            print('\tR^2:', expr2)
            print('\tDoubling Time (represents overall growth): ', round(edoubletime,2), '(±', round(edoubletimeerror,2),') days')
            print("\tparam: ", epopt)
            exponentialworked = True
        else:
            print("\n exponential R^2 ", expr2)
    except Exception as ex:
        cprint('\nException in exponential process ', 'red')
        cprint(type(ex), 'red')
        cprint(ex, 'red')

    if use_matplotlib:
        plt.title(country + ' Cumulative COVID-19 Cases. (Updated on '+mostrecentdate+')', fontsize="x-large")
        plt.xlabel('Days', fontsize="x-large")
        plt.ylabel('Total Cases', fontsize="x-large")
        plt.legend(fontsize="x-large")
        plt.show()
    else:
        fig.update_layout(title=country + ' Cumulative COVID-19 Cases. (Updated on '+mostrecentdate+')'
                          , xaxis_title='Days'
                          , yaxis_title='Total Cases'
                          , width=900, height=700, autosize=False
                          )
        fig.show()

    # Prefer the fit with the higher R^2 (compared at two decimals); when
    # they tie, the exponential fit is reported.
    if logisticworked and exponentialworked:
        if round(logisticr2,2) > round(expr2,2):
            return [ldoubletime, ldoubletimeerror, recentdbltime]
        return [edoubletime, edoubletimeerror, recentdbltime]

    if logisticworked:
        return [ldoubletime, ldoubletimeerror, recentdbltime]

    if exponentialworked:
        return [edoubletime, edoubletimeerror, recentdbltime]

    return [float('NaN'), float('NaN'), recentdbltime]
# Analyse every country left in `cases`, in its sorted order, and collect the
# results in four parallel lists (one entry per country that returned data).
topcountries = cases.index
cnames = topcountries.values

countries, inferreddoublingtime, errors, recentdoublingtime = [], [], [], []

print('\n')

for c in cnames:
    print(c)
    a = plotCases(df, 'Country/Region', c)
    if a:
        # plotCases returns [doubling time, error, recent doubling time].
        fit_days, fit_error, recent_days = a[0], a[1], a[2]
        countries.append(c)
        inferreddoublingtime.append(fit_days)
        errors.append(fit_error)
        recentdoublingtime.append(recent_days)
    # Blank line between the output of consecutive countries.
    print('\n')
# Gather the per-country result lists into one column-oriented mapping; the
# plotting code further down builds its DataFrames from `d`.
d = {'Countries': countries, 'Inferred Doubling Time': inferreddoublingtime, '95%CI': errors, 'Recent Doubling Time': recentdoublingtime}

print('\nInferred Doubling Times are inferred using curve fits.')
print('Recent Doubling Times are calculated using the most recent week of data.')
print('Shorter doubling time = faster growth, longer doubling time = slower growth.')
print('\n')

# Label each row with its country. Selecting the numeric columns by position
# dropped the names and left rows identified only by 0..N-1, which made the
# table impossible to read. The three value columns and rounding are unchanged.
print(pd.DataFrame(data=d).set_index('Countries').round(1))

print('\n')
# Bar chart of curve-fit doubling times (with error bars) for the countries
# whose inferred doubling time is below 100 days; NaN entries fail the
# comparison and are left out as well.
dt = pd.DataFrame(data=d).loc[lambda frame: frame['Inferred Doubling Time'] < 100]

dt.plot.bar(
    x='Countries',
    y='Inferred Doubling Time',
    yerr='95%CI',
    capsize=4,
    legend=False,
    figsize=(15, 7.5),
    fontsize="x-large",
)

# Dashed guide lines at 1, 3 and 5 days.
for guide_y in (1, 3, 5):
    plt.axhline(y=guide_y, linestyle='--')

plt.xlabel('Countries', fontsize="x-large")
plt.ylabel('Doubling Time (Days)', fontsize="x-large")
plt.title('Inferred Doubling Time of Cumulative COVID-19 Cases. Last update: ' + mostrecentdate, fontsize="x-large")
plt.show()
print('\n')

# Same chart as the 100-day version, zoomed in on countries whose inferred
# doubling time is below 10 days.
dt = pd.DataFrame(data=d).loc[lambda frame: frame['Inferred Doubling Time'] < 10]

dt.plot.bar(
    x='Countries',
    y='Inferred Doubling Time',
    yerr='95%CI',
    capsize=4,
    legend=False,
    figsize=(15, 7.5),
    fontsize="x-large",
)

plt.xlabel('Countries', fontsize="x-large")
plt.ylabel('Doubling Time (Days)', fontsize="x-large")

# Dashed guide lines at 1, 3 and 5 days.
for guide_y in (1, 3, 5):
    plt.axhline(y=guide_y, linestyle='--')

plt.title('Inferred Doubling Time of Cumulative COVID-19 Cases. Last update: ' + mostrecentdate, fontsize="x-large")
plt.show()
# Error-bar table aligned with the two plotted series: the curve-fit errors
# for the inferred doubling time, and NaN (no bar) for the recent one.
err = pd.DataFrame(
    {
        'Inferred Doubling Time': errors,
        'Recent Doubling Time': [float('NaN')] * len(errors),
    },
    index=countries,
)

print('\n')

# Grouped bars comparing inferred and recent doubling times, limited to
# countries whose recent doubling time is below 100 days.
dt = pd.DataFrame(
    {
        'Inferred Doubling Time': inferreddoublingtime,
        'Recent Doubling Time': recentdoublingtime,
    },
    index=countries,
).loc[lambda frame: frame['Recent Doubling Time'] < 100]

dt.plot.bar(yerr=err, capsize=4, figsize=(15, 7.5), fontsize="x-large")

plt.xlabel('Countries', fontsize="x-large")
plt.ylabel('Doubling Time (Days)', fontsize="x-large")

# Dashed guide lines at 1, 3 and 5 days.
for guide_y in (1, 3, 5):
    plt.axhline(y=guide_y, linestyle='--')

plt.title('Doubling Time of Cumulative COVID-19 Cases. Last update: ' + mostrecentdate, fontsize="x-large")
plt.show()
print('\n')

# Inferred vs. recent doubling times again, this time only for countries
# whose recent doubling time is below 10 days. Error bars come from the
# existing `err` table.
dt = pd.DataFrame(
    {
        'Inferred Doubling Time': inferreddoublingtime,
        'Recent Doubling Time': recentdoublingtime,
    },
    index=countries,
).loc[lambda frame: frame['Recent Doubling Time'] < 10]

dt.plot.bar(yerr=err, capsize=4, figsize=(15, 7.5), fontsize="x-large")

plt.xlabel('Countries', fontsize="x-large")
plt.ylabel('Doubling Time (Days)', fontsize="x-large")

# Dashed guide lines at 1, 3 and 5 days.
for guide_y in (1, 3, 5):
    plt.axhline(y=guide_y, linestyle='--')

plt.title('Doubling Time of Cumulative COVID-19 Cases. Last update: ' + mostrecentdate, fontsize="x-large")
plt.show()